"""Tests for migration_utils."""

from keras.initializers import GlorotUniform as V2GlorotUniform
from keras.legacy_tf_layers import migration_utils
import tensorflow as tf


class DeterministicRandomTestToolTest(tf.test.TestCase):
  """Tests for `migration_utils.DeterministicRandomTestTool`.

  Each test draws random tensors inside the tool's `scope()` using TF1-style
  graph/session code and/or TF2-style eager code, then checks whether the
  generated values match (or differ) as expected for the tool's `mode`.
  """

  def test_constant_mode_no_seed(self):
    """Test random tensor generation consistency in constant mode.

    Verify that random tensors generated without an explicit seed are
    consistent between graph and eager mode.
    """

    # Generate three random tensors to show how the stateful random number
    # generation and glorot_uniform_initializer match between sessions and
    # eager execution.
    random_tool = migration_utils.DeterministicRandomTestTool()
    with random_tool.scope():
      graph = tf.Graph()
      with graph.as_default(), tf.compat.v1.Session(graph=graph) as sess:
        a = tf.compat.v1.random.uniform(shape=(3, 1))
        # Add extra non-random computation/ops to the graph; random number
        # generation should stay consistent regardless.
        a = a * 3
        b = tf.compat.v1.random.uniform(shape=(3, 3))
        b = b * 3
        c = tf.compat.v1.random.uniform(shape=(3, 3))
        c = c * 3
        d = tf.compat.v1.glorot_uniform_initializer()(
            shape=(6, 6), dtype=tf.float32)
        graph_a, graph_b, graph_c, graph_d = sess.run([a, b, c, d])

      a = tf.compat.v2.random.uniform(shape=(3, 1))
      a = a * 3
      b = tf.compat.v2.random.uniform(shape=(3, 3))
      b = b * 3
      c = tf.compat.v2.random.uniform(shape=(3, 3))
      c = c * 3
      d = V2GlorotUniform()(shape=(6, 6), dtype=tf.float32)
    # Validate that the generated random tensors match.
    self.assertAllClose(graph_a, a)
    self.assertAllClose(graph_b, b)
    self.assertAllClose(graph_c, c)
    self.assertAllClose(graph_d, d)
    # In constant mode, b and c are generated with the same seed within the
    # same scope and have the same shape, so they have exactly the same
    # values. Validate that b == c and graph_b == graph_c.
    self.assertAllClose(b, c)
    self.assertAllClose(graph_b, graph_c)

  def test_constant_mode_seed_argument(self):
    """Test random tensor generation consistency in constant mode.

    Verify that random tensors generated with an explicit op-level `seed`
    argument are consistent between graph and eager mode.
    """
    random_tool = migration_utils.DeterministicRandomTestTool()
    with random_tool.scope():
      graph = tf.Graph()
      with graph.as_default(), tf.compat.v1.Session(graph=graph) as sess:
        # Add extra non-random computation/ops to the graph; random number
        # generation should stay consistent regardless.
        a = tf.compat.v1.random.uniform(shape=(3, 1), seed=1234)
        a = a * 3
        b = tf.compat.v1.random.uniform(shape=(3, 3), seed=1234)
        b = b * 3
        c = tf.compat.v1.glorot_uniform_initializer(seed=1234)(
            shape=(6, 6), dtype=tf.float32)
        graph_a, graph_b, graph_c = sess.run([a, b, c])
      a = tf.compat.v2.random.uniform(shape=(3, 1), seed=1234)
      a = a * 3
      b = tf.compat.v2.random.uniform(shape=(3, 3), seed=1234)
      b = b * 3
      c = V2GlorotUniform(seed=1234)(shape=(6, 6), dtype=tf.float32)

    # Validate that the generated random tensors match.
    self.assertAllClose(graph_a, a)
    self.assertAllClose(graph_b, b)
    self.assertAllClose(graph_c, c)

  def test_num_rand_ops(self):
    """Test random tensor generation consistency in num_random_ops mode.

    Verify that random tensors generated without an explicit seed are
    consistent between graph and eager mode, and that tensors differ based
    on the ordering of random ops.
    """
    random_tool = migration_utils.DeterministicRandomTestTool(
        mode="num_random_ops")
    with random_tool.scope():
      graph = tf.Graph()
      with graph.as_default(), tf.compat.v1.Session(graph=graph) as sess:
        # Add extra non-random computation/ops to the graph; random number
        # generation should stay consistent regardless.
        a = tf.compat.v1.random.uniform(shape=(3, 1))
        a = a * 3
        b = tf.compat.v1.random.uniform(shape=(3, 3))
        b = b * 3
        c = tf.compat.v1.random.uniform(shape=(3, 3))
        c = c * 3
        d = tf.compat.v1.glorot_uniform_initializer()(
            shape=(6, 6), dtype=tf.float32)
        graph_a, graph_b, graph_c, graph_d = sess.run([a, b, c, d])

    # A fresh tool restarts the operation-seed sequence so the eager run
    # mirrors the graph run above.
    random_tool = migration_utils.DeterministicRandomTestTool(
        mode="num_random_ops")
    with random_tool.scope():
      a = tf.compat.v2.random.uniform(shape=(3, 1))
      a = a * 3
      b = tf.compat.v2.random.uniform(shape=(3, 3))
      b = b * 3
      c = tf.compat.v2.random.uniform(shape=(3, 3))
      c = c * 3
      d = V2GlorotUniform()(shape=(6, 6), dtype=tf.float32)
    # Validate that the generated random tensors match.
    self.assertAllClose(graph_a, a)
    self.assertAllClose(graph_b, b)
    self.assertAllClose(graph_c, c)
    self.assertAllClose(graph_d, d)
    # Validate that the tensors differ based on op ordering.
    self.assertNotAllClose(b, c)
    self.assertNotAllClose(graph_b, graph_c)

  def test_num_rand_ops_program_order(self):
    """Test random tensor generation consistency in num_random_ops mode.

    Validate that in this mode random number generation is sensitive to
    program order, so the generated random tensors should not match.
    """
    random_tool = migration_utils.DeterministicRandomTestTool(
        mode="num_random_ops")
    with random_tool.scope():
      a = tf.random.uniform(shape=(3, 1))
      # Interleave extra non-random ops; random number generation should
      # stay consistent regardless.
      a = a * 3
      b = tf.random.uniform(shape=(3, 3))
      b = b * 3

    random_tool = migration_utils.DeterministicRandomTestTool(
        mode="num_random_ops")
    with random_tool.scope():
      b_prime = tf.random.uniform(shape=(3, 3))
      # Interleave extra non-random ops; random number generation should
      # stay consistent regardless.
      b_prime = b_prime * 3
      a_prime = tf.random.uniform(shape=(3, 1))
      a_prime = a_prime * 3
    # Validate that the tensors are different.
    self.assertNotAllClose(a, a_prime)
    self.assertNotAllClose(b, b_prime)

  def test_num_rand_ops_operation_seed(self):
    """Test random tensor generation consistency in num_random_ops mode.

    Validate that random number generation matches across two different
    program orders when `operation_seed` is set explicitly.
    """
    random_tool = migration_utils.DeterministicRandomTestTool(
        mode="num_random_ops")
    with random_tool.scope():
      # operation seed = 0
      a = tf.random.uniform(shape=(3, 1))
      a = a * 3
      # operation seed = 1
      b = tf.random.uniform(shape=(3, 3))
      b = b * 3

    random_tool = migration_utils.DeterministicRandomTestTool(
        mode="num_random_ops")
    with random_tool.scope():
      random_tool.operation_seed = 1
      b_prime = tf.random.uniform(shape=(3, 3))
      b_prime = b_prime * 3
      random_tool.operation_seed = 0
      a_prime = tf.random.uniform(shape=(3, 1))
      a_prime = a_prime * 3

    self.assertAllClose(a, a_prime)
    self.assertAllClose(b, b_prime)

  def test_num_rand_ops_disallow_repeated_ops_seed(self):
    """Test operation-seed reuse detection in num_random_ops mode.

    Validate that DeterministicRandomTestTool disallows reusing an
    already-used operation seed by raising a ValueError.
    """
    random_tool = migration_utils.DeterministicRandomTestTool(
        mode="num_random_ops")
    with random_tool.scope():
      # Consume operation seed 1, then operation seed 0. These ops are run
      # only for their effect on the tool's seed bookkeeping.
      random_tool.operation_seed = 1
      tf.random.uniform(shape=(3, 3))
      random_tool.operation_seed = 0
      tf.random.uniform(shape=(3, 1))
      # Operation seeds advance by one per random op (as exercised in
      # test_num_rand_ops_operation_seed), so the next op would reuse the
      # already-consumed seed 1 and must be rejected.
      with self.assertRaises(ValueError):
        tf.random.uniform(shape=(3, 1))


if __name__ == "__main__":
  # Run this module's tests through TensorFlow's test runner when the file
  # is executed directly as a script.
  tf.test.main()

